import logging
# Module import configures the *root* logger (getLogger() with no name).
# Its own level is opened all the way to DEBUG. If nothing has attached a
# handler yet, a StreamHandler (stderr by default) that lets INFO and above
# through is added, so DEBUG records are presumably meant only for handlers
# installed elsewhere.
logger = logging.getLogger()
logger.setLevel(level=logging.DEBUG)
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    handler.setLevel(level=logging.INFO)
    logger.addHandler(hdlr=handler)

# import sys
# sys.path.append('..')

from main import utils
from loss import pytorch_ssim

def compute_quality_metrics(img_tensor, gt_img_tensor, pipe, wm_pipe, threshold=0.9, device="cuda"):
    """Score ``img_tensor`` against ``gt_img_tensor`` and run the watermark detector on it.

    Args:
        img_tensor: image being evaluated.
        gt_img_tensor: reference image that SSIM / PSNR / L1 / LPIPS are measured against.
        pipe: object exposing ``get_text_embedding``; also forwarded to
            ``utils.watermark_prob``.
        wm_pipe: forwarded to ``utils.watermark_prob``.
        threshold: ``det_prob`` must be strictly greater than this for the
            image to be classified as watermarked.
        device: forwarded to ``utils.compute_lpips``.

    Returns:
        ``(ssim, psnr, det_prob, l1, wm_classification, lpips)`` where
        ``det_prob = 1 - utils.watermark_prob(...)`` and
        ``wm_classification = int(det_prob > threshold)``. Assumes
        ``det_prob`` is a scalar (TODO confirm against ``utils.watermark_prob``).
    """
    # Similarity of the evaluated image to the reference.
    ssim_score = pytorch_ssim.ssim(img_tensor, gt_img_tensor).item()
    psnr_score = utils.compute_psnr(img_tensor, gt_img_tensor)

    # The detector receives the text embedding of an empty prompt.
    empty_prompt_embedding = pipe.get_text_embedding('')
    detection_prob = 1 - utils.watermark_prob(
        img_tensor, pipe, wm_pipe, empty_prompt_embedding
    )

    l1_score = utils.compute_l1(img_tensor, gt_img_tensor)
    lpips_score = utils.compute_lpips(img_tensor, gt_img_tensor, device)

    is_watermarked = int(detection_prob > threshold)
    return ssim_score, psnr_score, detection_prob, l1_score, is_watermarked, lpips_score

def adaptive_enhancement(gt_img_tensor, wm_img_tensor, ssim_threshold, precision=0.01, max_iter=1000):
    """Blend a watermarked image back toward the original to approach a target SSIM.

    Candidate images are the linear blends
    ``wm_img_tensor + theta * (gt_img_tensor - wm_img_tensor)`` for ``theta`` in
    ``[0, 1]``. ``theta = 0`` is the watermarked image and ``theta = 1`` is
    (up to rounding) the original. A bisection on ``theta`` brackets the point
    where SSIM against ``gt_img_tensor`` crosses ``ssim_threshold``. The lower
    end of the final bracket is kept: the largest evaluated ``theta`` whose
    SSIM was ``<= ssim_threshold``, or ``0.0`` if no evaluated ``theta``
    qualified.

    The bisection is only meaningful if SSIM grows monotonically with
    ``theta``; that is assumed, not checked.

    Args:
        gt_img_tensor: reference (unwatermarked) image tensor.
        wm_img_tensor: watermarked image tensor; presumably the same shape as
            ``gt_img_tensor`` (TODO confirm).
        ssim_threshold: target SSIM value.
        precision: stop once the ``theta`` bracket is narrower than this.
            Defaults to 0.01, the value this function always used before it
            became a parameter.
        max_iter: upper bound on bisection steps.

    Returns:
        The blended image tensor at the selected ``theta``.
    """
    # Loop-invariant: the difference image does not depend on theta.
    delta = gt_img_tensor - wm_img_tensor

    def blend(theta):
        # theta = 0 -> wm_img_tensor; theta = 1 -> gt_img_tensor (up to rounding).
        return delta * theta + wm_img_tensor

    # Bisection over theta. Each step moves `lower` up when the midpoint's
    # SSIM is at or below the target, otherwise moves `upper` down.
    lower, upper = 0., 1.
    for _ in range(max_iter):
        mid_theta = (lower + upper) / 2
        ssim_value = pytorch_ssim.ssim(blend(mid_theta), gt_img_tensor).item()
        if ssim_value <= ssim_threshold:
            lower = mid_theta
        else:
            upper = mid_theta
        if upper - lower < precision:
            break
    optimal_theta = lower

    logger.info('Optimal Theta %s', optimal_theta)

    return blend(optimal_theta)


# -----------------------------------------------------------------------------------
if __name__ == '__main__':
    # Manual sanity check: enhance a sample watermarked image to a target SSIM
    # and print its SSIM next to the SSIM of a previously saved reference
    # output. Paths default to the original author's local checkout; pass
    # them on the command line to run anywhere else.
    import argparse

    import torch

    parser = argparse.ArgumentParser(
        description='Compare adaptive_enhancement against a saved reference output.')
    parser.add_argument('--gt', default='/home/adityag/ZoDiac/example/input/pepper.tiff',
                        help='original (unwatermarked) image')
    parser.add_argument('--wm', default='/home/adityag/ZoDiac/example/output/pepper_100.png',
                        help='watermarked image to enhance')
    parser.add_argument('--reference',
                        default='/home/adityag/ZoDiac/example/output/pepper_100_SSIM0.92.png',
                        help='previously enhanced image to compare against')
    parser.add_argument('--ssim-threshold', type=float, default=0.92,
                        help='target SSIM passed to adaptive_enhancement')
    parser.add_argument('--device', default='cpu',
                        help='torch device the images are loaded on')
    args = parser.parse_args()

    device = torch.device(args.device)

    gt_img_tensor = utils.get_img_tensor(args.gt, device)
    wm_img_tensor = utils.get_img_tensor(args.wm, device)

    ssim_threshold = args.ssim_threshold
    enhanced_img_tensor = adaptive_enhancement(gt_img_tensor, wm_img_tensor, ssim_threshold)

    wm_img_tensor_gt = utils.get_img_tensor(args.reference, device)
    ssim_value1 = pytorch_ssim.ssim(enhanced_img_tensor, gt_img_tensor).item()
    ssim_value2 = pytorch_ssim.ssim(wm_img_tensor_gt, gt_img_tensor).item()

    # SSIM of our enhancement vs. SSIM of the saved reference; expected to be close.
    print(ssim_value1, ssim_value2)